import os 
import multiprocessing

def main():
    """Evaluate a saved model on a test set and dump per-class hidden vectors.

    Steps performed:

    1. Parse the command-line options.
    2. Load ``--model_dataset`` and then ``--dataset``; the latter is loaded
       with the ``mean``/``std`` attributes of the former.
    3. Obtain the network from ``instantiate_model(..., load=True)`` and put
       it in eval mode.
    4. Count the test samples per class; the smallest count becomes
       ``points_per_class`` so every class contributes equally many points.
    5. Run the network over the test loader, accumulating accuracy and
       storing, for each class, the hidden vectors of the first
       ``points_per_class`` samples encountered.
    6. Save the ``(num_classes, points_per_class, in_features)`` tensor under
       ``./outputs/latent_space/`` and print the accuracy.

    Side effects:
        Sets the module-level ``args``, prints to stdout, and writes one
        ``.vec`` file (the output directory is created when missing).

    Note:
        ``--parallel`` is parsed but not read anywhere in this function.
    """
    import argparse
    import torch
    from utils.str2bool import str2bool
    from utils.load_dataset import load_dataset
    from utils.instantiate_model import instantiate_model

    parser = argparse.ArgumentParser(description='Train', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Training parameters
    parser.add_argument('--model_dataset',          default='CIFAR10',      type=str,       help='Dataset to used to train the model')
    parser.add_argument('--dataset',                default='CIFAR10',      type=str,       help='Set dataset to use')
    parser.add_argument('--parallel',               default=False,          type=str2bool,  help='Device in  parallel')

    # Dataloader args
    parser.add_argument('--train_batch_size',       default=1024,            type=int,       help='Train batch size')
    parser.add_argument('--test_batch_size',        default=1024,            type=int,       help='Test batch size')
    parser.add_argument('--val_split',              default=0.1,            type=float,     help='Fraction of training dataset split as validation')
    parser.add_argument('--augment',                default=True,           type=str2bool,  help='Random horizontal flip and random crop')
    parser.add_argument('--padding_crop',           default=4,              type=int,       help='Padding for random crop')
    parser.add_argument('--shuffle',                default=True,           type=str2bool,  help='Shuffle the training dataset')
    parser.add_argument('--random_seed',            default=0,              type=int,       help='Initialising the seed for reproducibility')
    parser.add_argument('--arch',                   default='resnet18',     type=str,       help='Network architecture')
    parser.add_argument('--suffix',                 default='',             type=str,       help='Appended to model name')

    # Kept as a module-level global so the parsed options stay reachable
    # exactly as before.
    global args
    args = parser.parse_args()
    print(args)

    # Setup right device to run on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('\n')

    # Both datasets are loaded with the same loader settings; only the
    # dataset name (and the normalisation statistics) differ.
    loader_kwargs = dict(train_batch_size=args.train_batch_size,
                         test_batch_size=args.test_batch_size,
                         val_split=args.val_split,
                         augment=args.augment,
                         padding_crop=args.padding_crop,
                         shuffle=args.shuffle,
                         random_seed=args.random_seed,
                         device=device)

    model_dataset = load_dataset(dataset=args.model_dataset, **loader_kwargs)

    # The evaluation dataset receives the model dataset's mean/std.
    dataset = load_dataset(dataset=args.dataset,
                           mean=model_dataset.mean,
                           std=model_dataset.std,
                           **loader_kwargs)

    # Instantiate model (the returned model name is not needed here)
    net, _ = instantiate_model(dataset=model_dataset,
                               arch=args.arch,
                               suffix=args.suffix,
                               load=True,
                               torch_weights=False,
                               device=device)
    net.eval()

    num_classes = dataset.num_classes

    # Count number of samples in each class
    sample_count = torch.zeros([num_classes])
    for _, labels in dataset.test_loader:
        sample_count += torch.nn.functional.one_hot(labels, num_classes).sum(dim=0)

    print("Class count on {}\n{}".format(dataset.name, sample_count))

    # Cap every class at the size of the smallest one so the output is balanced.
    points_per_class = int(sample_count.min().item())

    # The hidden vector width is taken from the last child module of the
    # network (assumed to be the final linear classifier -- TODO confirm).
    classifier = list(net.children())[-1]
    hidden_rep = torch.zeros((num_classes, points_per_class, classifier.in_features))

    # Number of vectors already stored per class; plain ints avoid tensor
    # indexing overhead in the per-sample loop below.
    filled = [0] * num_classes

    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in dataset.test_loader:
            data = data.to(device)
            labels = labels.to(device)
            out, hidden = net(data, latent=True)
            _, pred = torch.max(out, dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

            # One device-to-host copy per batch instead of one per sample.
            hidden_cpu = hidden.cpu()
            for i, label in enumerate(labels.tolist()):
                slot = filled[label]
                if slot >= points_per_class:
                    continue
                hidden_rep[label, slot] = hidden_cpu[i]
                filled[label] = slot + 1

    # An empty test loader would otherwise raise ZeroDivisionError.
    accuracy = float(correct) * 100.0 / float(total) if total else 0.0

    # torch.save does not create missing parent directories.
    output_dir = './outputs/latent_space'
    os.makedirs(output_dir, exist_ok=True)

    dataset_tag = args.dataset.lower()
    model_tag = args.model_dataset.lower()
    if model_tag == dataset_tag:
        file_name = '{}_{}_{}.vec'.format(dataset_tag, args.arch, args.suffix)
    else:
        file_name = '{}_on_{}_{}_{}.vec'.format(dataset_tag, model_tag, args.arch, args.suffix)
    torch.save(hidden_rep, '{}/{}'.format(output_dir, file_name))

    print("Accuracy {:.2f}".format(accuracy))

if __name__ == "__main__":
    # Script entry point. On Windows ('nt'), freeze_support() is invoked
    # before anything else so multiprocessing works as expected.
    if os.name == 'nt':
        multiprocessing.freeze_support()

    main()